{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "74083d2a-cf45-4bae-be70-57eefa5de105",
   "metadata": {},
   "source": [
    "# HF Transformers 核心模块学习：Pipelines 进阶\n",
    "\n",
    "我们已经学习了 Pipeline API 针对各类任务的基本使用。\n",
    "\n",
    "实际上，在 Transformers 库内部实现中，Pipeline 作为管理：`原始文本-输入Token IDs-模型推理-输出概率-生成结果` 的流水线抽象，背后串联了 Transformers 库的核心模块 `Tokenizer`和 `Models`。\n",
    "\n",
    "![](docs/images/pipeline_advanced.png)\n",
    "\n",
    "下面我们开始结合大语言模型（在 Transformers 中也是一种特定任务）学习：\n",
    "\n",
    "- 使用 Pipeline 如何与现代的大语言模型结合，以完成各类下游任务\n",
    "- 使用 Tokenizer 编解码文本\n",
    "- 使用 Models 加载和保存模型"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "9dc4e96d-49a1-4ac2-843f-e9e21af9602c",
   "metadata": {},
   "source": [
    "## 使用 Pipeline 调用大语言模型\n",
    "\n",
    "### Language Modeling\n",
    "\n",
    "语言建模是一项预测文本序列中的单词的任务。它已经成为非常流行的自然语言处理任务，因为预训练的语言模型可以用于许多其他下游任务的微调。最近，对大型语言模型（LLMs）产生了很大兴趣，这些模型展示了零或少量样本学习能力。这意味着该模型可以解决其未经明确训练过的任务！虽然语言模型可用于生成流畅且令人信服的文本，但需要小心使用，因为文本可能并不总是准确无误。\n",
    "\n",
    "通过理论篇学习，我们了解到有两种典型的语言模型：\n",
    "\n",
    "- 自回归：模型目标是预测序列中的下一个 Token（文本），训练时对下文进行了掩码。如：GPT-3。\n",
    "- 自编码：模型目标是理解上下文后，补全句子中丢失/掩码的 Token（文本）。如：BERT。\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "501715fc-ec83-4f15-9f85-800e2dfd9398",
   "metadata": {},
   "source": [
    "### 使用 GPT-2 实现文本生成\n",
    "\n",
    "![](docs/images/gpt2.png)\n",
    "\n",
    "模型主页：https://huggingface.co/gpt2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "6e27276b-c488-4205-a94a-e77067aee5db",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': \"Hugging Face is a community-based open-source platform for machine learning. It combines Python's fast performance in a very compact form and makes it easy to learn a complex field of machine learning.\"}]"
      ]
     },
     "execution_count": 1,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "prompt = \"Hugging Face is a community-based open-source platform for machine learning.\"\n",
    "generator = pipeline(task=\"text-generation\", model=\"gpt2\")\n",
    "generator(prompt)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "71218c39-488f-40f0-9ebd-64e00a66b825",
   "metadata": {},
   "source": [
    "#### 设置文本生成返回条数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "960eec18-cd84-4605-b0d5-a52191c278e4",
   "metadata": {},
   "outputs": [],
   "source": [
    "prompt = \"You are very smart\"\n",
    "generator = pipeline(task=\"text-generation\", model=\"gpt2\", num_return_sequences=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "09fcddb8-bd6b-4194-add6-f63077ff0192",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': \"You are very smart, he is your best friend. Do you know my secret? Do the test. I always do. No. I can't remember you. Oh. Sorry. Your mother is dead, does she still still drive? No.\"},\n",
       " {'generated_text': \"You are very smart and you've got big hands. People like to play the game to the end. But when they get on the bus, they can take their kids. When your kid runs away, they don't care. They keep to themselves\"},\n",
       " {'generated_text': 'You are very smart\" and \"good for business\" will not be found anywhere: \"You are smart because you love your boss or friend.\" This sentiment is only expressed by people who love to hear they are, in fact, employees of company,'}]"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generator(prompt)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "7ae18861-4ac3-4df2-8716-c0bde84fe601",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': 'You are very smart young boy who is being bullied but I want you to realize that a teacher is my best friend and I feel comfortable telling him my story. Thank you so much.\"'},\n",
       " {'generated_text': 'You are very smart.\"\\n\\nHe said I don\\'t think he\\'s being disrespectful. One of the most important things he\\'s asked me about is how he feels, what it\\'s like to be the person you were.\"\\n\\nMoss was'}]"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generator(prompt, num_return_sequences=2)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "be167b91-898f-4840-9da9-3646e91bfdbf",
   "metadata": {},
   "source": [
    "#### 设置文本生成最大长度"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "496ffdf4-a3a6-4396-92ae-47b4784766dc",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "[{'generated_text': 'You are very smart, are your hands so good, why not make a movie'},\n",
       " {'generated_text': 'You are very smart and incredibly well supported in the community and I was very nervous'}]"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "generator(prompt, num_return_sequences=2, max_length=16)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "37bad445-bbd9-49a0-875b-624711263c00",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "e1b519e4-2a3b-47d8-98f2-96d8e9980b93",
   "metadata": {},
   "source": [
    "### 使用 BERT-Base-Chinese 实现中文补全\n",
    "\n",
    "![](docs/images/bert-base-chinese.png)\n",
    "\n",
    "模型主页：https://huggingface.co/bert-base-chinese"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "f1a09f2a-17a3-436a-b98d-3f65ef68bf47",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "Some weights of the model checkpoint at bert-base-chinese were not used when initializing BertForMaskedLM: ['bert.pooler.dense.weight', 'cls.seq_relationship.weight', 'bert.pooler.dense.bias', 'cls.seq_relationship.bias']\n",
      "- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
      "- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
     ]
    }
   ],
   "source": [
    "from transformers import pipeline\n",
    "\n",
    "fill_mask = pipeline(task=\"fill-mask\", model=\"bert-base-chinese\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "id": "ac88584b-6826-4217-8b45-47c7a09be121",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.9203723669052124,\n",
       "  'token': 679,\n",
       "  'token_str': '不',\n",
       "  'sequence': '人 民 是 不 可 战 胜 的'}]"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = \"人民是[MASK]可战胜的\"\n",
    "\n",
    "fill_mask(text, top_k=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "00e2b95b-4a9a-44ab-a065-fb9e1f1dd949",
   "metadata": {},
   "source": [
    "#### 设置文本补全的条数"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "id": "a20e3400-6ea1-47a7-b90f-5ef900c0f16d",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.7596911191940308,\n",
       "  'token': 8043,\n",
       "  'token_str': '？',\n",
       "  'sequence': '美 国 的 首 都 是 ？'}]"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = \"美国的首都是[MASK]\"\n",
    "\n",
    "fill_mask(text, top_k=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "id": "7605103b-47b2-4c39-b9ad-bd2cf81dc4c8",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.9911912083625793,\n",
       "  'token': 3791,\n",
       "  'token_str': '法',\n",
       "  'sequence': '巴 黎 是 法 国 的 首 都 。'}]"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = \"巴黎是[MASK]国的首都。\"\n",
    "fill_mask(text, top_k=1)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "id": "e5ca21ff-76ee-4447-8887-25371f9cec85",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[{'score': 0.7596911191940308,\n",
       "  'token': 8043,\n",
       "  'token_str': '？',\n",
       "  'sequence': '美 国 的 首 都 是 ？'},\n",
       " {'score': 0.21126744151115417,\n",
       "  'token': 511,\n",
       "  'token_str': '。',\n",
       "  'sequence': '美 国 的 首 都 是 。'},\n",
       " {'score': 0.026834219694137573,\n",
       "  'token': 8013,\n",
       "  'token_str': '！',\n",
       "  'sequence': '美 国 的 首 都 是 ！'}]"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = \"美国的首都是[MASK]\"\n",
    "fill_mask(text, top_k=3)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "id": "843ae228-7063-4886-bfb6-f9a6173f341e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[[{'score': 0.5740261673927307,\n",
       "   'token': 5294,\n",
       "   'token_str': '纽',\n",
       "   'sequence': '[CLS] 美 国 的 首 都 是 纽 [MASK] [MASK] [SEP]'}],\n",
       " [{'score': 0.4926738142967224,\n",
       "   'token': 5276,\n",
       "   'token_str': '约',\n",
       "   'sequence': '[CLS] 美 国 的 首 都 是 [MASK] 约 [MASK] [SEP]'}],\n",
       " [{'score': 0.9353252053260803,\n",
       "   'token': 511,\n",
       "   'token_str': '。',\n",
       "   'sequence': '[CLS] 美 国 的 首 都 是 [MASK] [MASK] 。 [SEP]'}]]"
      ]
     },
     "execution_count": 14,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "text = \"美国的首都是[MASK][MASK][MASK]\"\n",
    "\n",
    "fill_mask(text, top_k=1)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f4420914-a9b4-49db-b75b-c05beca89f0f",
   "metadata": {},
   "source": [
    "#### 思考：sequence 中出现的 [CLS] 和 [SEP] 是什么？"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1d848e51-7ae3-4ed1-9a0d-3432d85228f7",
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "cell_type": "markdown",
   "id": "1b8dc28a-e5d3-486e-963d-202798bb1db7",
   "metadata": {},
   "source": [
    "## 使用 AutoClass 高效管理 `Tokenizer` 和 `Model`\n",
    "\n",
    "通常，您想要使用的模型（网络架构）可以从您提供给 `from_pretrained()` 方法的预训练模型的名称或路径中推测出来。\n",
    "\n",
    "AutoClasses就是为了帮助用户完成这个工作，以便根据`预训练权重/配置文件/词汇表的名称/路径自动检索相关模型`。\n",
    "\n",
    "比如手动加载`bert-base-chinese`模型以及对应的 `tokenizer` 方法如下：\n",
    "\n",
    "```python\n",
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-chinese\")\n",
    "model = AutoModel.from_pretrained(\"bert-base-chinese\")\n",
    "```\n",
    "\n",
    "以下是我们实际操作和演示：\n",
    "\n",
    "### 使用 `from_pretrained` 方法加载指定 Model 和 Tokenizer "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "id": "c18dbf7d-7010-4f61-8c78-ec5ebfd4a49b",
   "metadata": {},
   "outputs": [],
   "source": [
    "from transformers import AutoTokenizer, AutoModel\n",
    "\n",
    "model_name = \"bert-base-chinese\"\n",
    "\n",
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "model = AutoModel.from_pretrained(model_name)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b4abfde6-b493-40af-b09b-4533dc6450ab",
   "metadata": {},
   "source": [
    "#### 使用 BERT Tokenizer 编码文本\n",
    "\n",
    "编码 (Encoding) 过程包含两个步骤：\n",
    "\n",
    "- 分词：使用分词器按某种策略将文本切分为 tokens；\n",
    "- 映射：将 tokens 转化为对应的 token IDs。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 16,
   "id": "be846040-a2f2-485a-a1f3-a89e8f762a28",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['美', '国', '的', '首', '都', '是', '华', '盛', '顿', '特', '区']\n"
     ]
    }
   ],
   "source": [
    "# 第一步：分词\n",
    "sequence = \"美国的首都是华盛顿特区\"\n",
    "tokens = tokenizer.tokenize(sequence)\n",
    "print(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 17,
   "id": "3870d69e-0efb-4807-b23a-7092be8e8829",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 第二步：映射\n",
    "token_ids = tokenizer.convert_tokens_to_ids(tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 18,
   "id": "c2c57682-1604-49e1-9b74-ff4244c1be21",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[5401, 1744, 4638, 7674, 6963, 3221, 1290, 4670, 7561, 4294, 1277]\n"
     ]
    }
   ],
   "source": [
    "print(token_ids)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76db0ce9-67cf-4cf8-b85c-b88e3fbb63a6",
   "metadata": {},
   "source": [
    "#### 使用 Tokenizer.encode 方法端到端处理\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 19,
   "id": "6b804143-de27-4fad-bb94-90cfff2d503c",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_ids_e2e = tokenizer.encode(sequence)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 20,
   "id": "06235807-6d1e-4c4f-be68-1d775e8d4164",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "[101, 5401, 1744, 4638, 7674, 6963, 3221, 1290, 4670, 7561, 4294, 1277, 102]"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "token_ids_e2e"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5d2ff88a-3a6d-4dbb-9f16-8bc82d0edb06",
   "metadata": {},
   "source": [
    "#### 思考：为什么前后新增了 101 和 102？"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "id": "e2ef8588-27cd-4380-a4dd-45f762560067",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'美 国 的 首 都 是 华 盛 顿 特 区'"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(token_ids)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 22,
   "id": "57842742-83ef-41d2-a1bc-fe0792885edb",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] 美 国 的 首 都 是 华 盛 顿 特 区 [SEP]'"
      ]
     },
     "execution_count": 22,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(token_ids_e2e)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "b67fb249-a512-44ae-8a8d-63d12e1b1537",
   "metadata": {},
   "source": [
    "#### 编解码多段文本"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "id": "b339389e-50e6-4aeb-8262-0c9325690dfb",
   "metadata": {},
   "outputs": [],
   "source": [
    "sequence_batch = [\"美国的首都是华盛顿特区\", \"中国的首都是北京\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 24,
   "id": "b22b5448-3028-4723-a568-cb7089eedc22",
   "metadata": {},
   "outputs": [],
   "source": [
    "token_ids_batch = tokenizer.encode(sequence_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 25,
   "id": "75cd9676-afc0-46dd-8a0b-315bfb1d4b7f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'[CLS] 美 国 的 首 都 是 华 盛 顿 特 区 [SEP] 中 国 的 首 都 是 北 京 [SEP]'"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.decode(token_ids_batch)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "48206a82-0a62-4db4-9550-3b5db83f65ee",
   "metadata": {},
   "source": [
    "![](docs/images/bert_pretrain.png)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "790f3f7c-da00-40fd-a9a8-2cf31164c638",
   "metadata": {},
   "source": [
    "### 实操建议：直接使用 tokenizer.\\_\\_call\\_\\_ 方法完成文本编码 + 特殊编码补全\n",
    "\n",
    "编码后返回结果：\n",
    "\n",
    "```json\n",
    "input_ids: token_ids\n",
    "token_type_ids: token_id 归属的句子编号\n",
    "attention_mask: 指示哪些token需要被关注（注意力机制）\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 30,
   "id": "3ce65811-a559-46da-bb24-1550339dd291",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "{'input_ids': [101, 5401, 1744, 4638, 7674, 6963, 3221, 1290, 4670, 7561, 4294, 1277, 102, 704, 1744, 4638, 7674, 6963, 3221, 1266, 776, 102], 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1], 'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]}\n"
     ]
    }
   ],
   "source": [
    "embedding_batch = tokenizer(\"美国的首都是华盛顿特区\", \"中国的首都是北京\")\n",
    "print(embedding_batch)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "fd553fdf-9402-46c8-9e86-ee87c4c4a644",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "input_ids: [101, 5401, 1744, 4638, 7674, 6963, 3221, 1290, 4670, 7561, 4294, 1277, 102, 704, 1744, 4638, 7674, 6963, 3221, 1266, 776, 102]\n",
      "\n",
      "token_type_ids: [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "\n",
      "attention_mask: [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]\n",
      "\n"
     ]
    }
   ],
   "source": [
    "# 优化下输出结构\n",
    "for key, value in embedding_batch.items():\n",
    "    print(f\"{key}: {value}\\n\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "8e2eb06b-582e-474d-b273-9129f2a5e3f7",
   "metadata": {},
   "source": [
    "### 添加新 Token\n",
    "\n",
    "当出现了词表或嵌入空间中不存在的新Token，需要使用 Tokenizer 将其添加到词表中。 Transformers 库提供了两种不同方法：\n",
    "\n",
    "- add_tokens: 添加常规的正文文本 Token，以追加（append）的方式添加到词表末尾。\n",
    "- add_special_tokens: 添加特殊用途的 Token，优先在已有特殊词表中选择（`bos_token, eos_token, unk_token, sep_token, pad_token, cls_token, mask_token`）。如果预定义均不满足，则都添加到`additional_special_tokens`。\n",
    "\n",
    "#### 添加常规 Token\n",
    "\n",
    "先查看已有词表，确保新添加的 Token 不在词表中："
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "051afd4f-371c-4587-8dce-bc50c2b7d926",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21128"
      ]
     },
     "execution_count": 32,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(tokenizer.vocab.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 35,
   "id": "877756c3-15e8-4837-9ab2-2d4378a7be16",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "##閃: 20329\n",
      "要: 6206\n",
      "淮: 3917\n",
      "sports: 10646\n",
      "疑: 4542\n",
      "193: 10185\n",
      "##倉: 13999\n",
      "##艰: 18737\n",
      "碳: 4823\n",
      "hotel: 8462\n"
     ]
    }
   ],
   "source": [
    "from itertools import islice\n",
    "\n",
    "# 使用 islice 查看词表部分内容\n",
    "for key, value in islice(tokenizer.vocab.items(), 10):\n",
    "    print(f\"{key}: {value}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 36,
   "id": "27763ad6-16b2-4653-b3f2-03e1240e9178",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_tokens = [\"天干\", \"地支\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 37,
   "id": "3702b14a-dd84-4130-883a-f202bb31ef98",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 将集合作差结果添加到词表中\n",
    "new_tokens = set(new_tokens) - set(tokenizer.vocab.keys())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 38,
   "id": "1abcee64-ecce-459d-87a1-08f5c5a18260",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'地支', '天干'}"
      ]
     },
     "execution_count": 38,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "new_tokens"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "0b5be3c5-2fb5-4a10-9eb2-abb7d795385f",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2"
      ]
     },
     "execution_count": 39,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.add_tokens(list(new_tokens))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "id": "91da172b-7f67-4509-b2fc-08272a5a5ee6",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21130"
      ]
     },
     "execution_count": 40,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 新增加了2个Token，词表总数由 21128 增加到 21130\n",
    "len(tokenizer.vocab.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "6fae854f-1e80-4e3b-b289-a0cb3b28bf28",
   "metadata": {},
   "source": [
    "#### 添加特殊Token（审慎操作）"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "930936d7-958f-458e-896a-36ab1bee9f29",
   "metadata": {},
   "outputs": [],
   "source": [
    "new_special_token = {\"sep_token\": \"NEW_SPECIAL_TOKEN\"}"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "id": "2cd63af1-64d5-4ff1-aa31-c7b6dab3d8a1",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1"
      ]
     },
     "execution_count": 42,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.add_special_tokens(new_special_token)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "97917514-889a-4f77-9d01-7d4dea561a94",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "21131"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 新增加了1个特殊Token，词表总数由 21128 增加到 21131\n",
    "len(tokenizer.vocab.keys())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49f5e0dd-25e3-4530-b805-96736ba80da7",
   "metadata": {},
   "source": [
    "### 使用 `save_pretrained` 方法保存指定 Model 和 Tokenizer \n",
    "\n",
    "借助 `AutoClass` 的设计理念，保存 Model 和 Tokenizer 的方法也相当高效便捷。\n",
    "\n",
    "假设我们对`bert-base-chinese`模型以及对应的 `tokenizer` 做了修改，并更名为`new-bert-base-chinese`，方法如下：\n",
    "\n",
    "```python\n",
    "tokenizer.save_pretrained(\"./models/new-bert-base-chinese\")\n",
    "model.save_pretrained(\"./models/new-bert-base-chinese\")\n",
    "```\n",
    "\n",
    "保存 Tokenizer 会在指定路径下创建以下文件：\n",
    "- tokenizer.json: Tokenizer 元数据文件；\n",
    "- special_tokens_map.json: 特殊字符映射关系配置文件；\n",
    "- tokenizer_config.json: Tokenizer 基础配置文件，存储构建 Tokenizer 需要的参数；\n",
    "- vocab.txt: 词表文件；\n",
    "- added_tokens.json: 单独存放新增 Tokens 的配置文件。\n",
    "\n",
    "保存 Model 会在指定路径下创建以下文件：\n",
    "- config.json：模型配置文件，存储模型结构参数，例如 Transformer 层数、特征空间维度等；\n",
    "- pytorch_model.bin：又称为 state dictionary，存储模型的权重。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 44,
   "id": "40adadd6-3f05-4345-92a1-ef93cffb8773",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "('./models/new-bert-base-chinese/tokenizer_config.json',\n",
       " './models/new-bert-base-chinese/special_tokens_map.json',\n",
       " './models/new-bert-base-chinese/vocab.txt',\n",
       " './models/new-bert-base-chinese/added_tokens.json',\n",
       " './models/new-bert-base-chinese/tokenizer.json')"
      ]
     },
     "execution_count": 44,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "tokenizer.save_pretrained(\"./models/new-bert-base-chinese\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 46,
   "id": "e2fbb7cd-16c7-4f83-a3a0-006caf083d97",
   "metadata": {},
   "outputs": [],
   "source": [
    "model.save_pretrained(\"./models/new-bert-base-chinese\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4509a3f5-fc80-4d7f-96ac-fae6e2e4d2ca",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
